package vulnerability

import (
	"bufio"
	"context"
	"io/ioutil"
	"os"
	"sort"
	"strings"

	"github.com/google/wire"
	"github.com/open-policy-agent/opa/rego"
	"golang.org/x/xerrors"

	"github.com/aquasecurity/trivy-db/pkg/db"
	dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
	"github.com/aquasecurity/trivy-db/pkg/vulnsrc/vulnerability"
	"github.com/aquasecurity/trivy/pkg/log"
	"github.com/aquasecurity/trivy/pkg/types"
	"github.com/aquasecurity/trivy/pkg/utils"
)

// DefaultIgnoreFile is the default name of the ignore file: a text file
// listing vulnerability IDs to drop from results, one per line, where blank
// lines and lines starting with "#" are skipped (see getIgnoredIDs).
const DefaultIgnoreFile = ".trivyignore"

// primaryURLPrefixes maps a vulnerability data source to the URL prefixes
// of its own advisory pages. getPrimaryURL uses it to pick a reference URL
// when the vulnerability ID has no fixed URL scheme; for each source the
// prefixes are tried in the listed order.
var primaryURLPrefixes = map[string][]string{
	vulnerability.Debian:           {"http://www.debian.org", "https://www.debian.org"},
	vulnerability.NodejsSecurityWg: {"https://www.npmjs.com", "https://hackerone.com"},
	vulnerability.OpenSuseCVRF:     {"http://lists.opensuse.org", "https://lists.opensuse.org"},
	vulnerability.OracleOVAL:       {"http://linux.oracle.com/errata", "https://linux.oracle.com/errata"},
	vulnerability.RedHat:           {"https://access.redhat.com"},
	vulnerability.RubySec:          {"https://groups.google.com"},
	vulnerability.SuseCVRF:         {"http://lists.opensuse.org", "https://lists.opensuse.org"},
	vulnerability.Ubuntu:           {"http://www.ubuntu.com", "https://usn.ubuntu.com"},
}

// SuperSet is the google/wire provider set for this package. It provides a
// db.Config via wire.Struct (no field names are listed, so presumably no
// fields are injected — confirm against the wire version in use), registers
// NewClient as the Client provider, and binds Client as the implementation
// of the Operation interface.
var SuperSet = wire.NewSet(
	wire.Struct(new(db.Config)),
	NewClient,
	wire.Bind(new(Operation), new(Client)),
)

// Operation is the set of post-processing steps applied to detected
// vulnerabilities. Client implements it, and SuperSet binds Client to this
// interface.
type Operation interface {
	// FillInfo enriches vulns in place with details from the vulnerability
	// DB, using reportType to choose the preferred data source.
	FillInfo(vulns []types.DetectedVulnerability, reportType string)

	// Filter returns the vulnerabilities that match severities and survive
	// the unfixed/ignore-file/policy filters.
	Filter(
		ctx context.Context,
		vulns []types.DetectedVulnerability,
		severities []dbTypes.Severity,
		ignoreUnfixed bool,
		ignoreFile string,
		policy string,
	) ([]types.DetectedVulnerability, error)
}

// Client implements Operation on top of a vulnerability database handle.
type Client struct {
	// dbc is used by FillInfo to look up vulnerability details by ID.
	dbc db.Operation
}

// NewClient returns a Client that reads vulnerability details through dbc.
// It is registered as a provider in SuperSet.
func NewClient(dbc db.Config) Client {
	var c Client
	c.dbc = dbc
	return c
}

// FillInfo enriches vulns in place with details from the vulnerability DB.
//
// reportType identifies what produced the findings: an OS family such as
// vulnerability.Debian, or a package manager such as "npm" or "cargo". It is
// mapped to the data source whose vendor severity and reference URLs are
// preferred. For every entry whose details are found, Severity,
// SeveritySource and PrimaryURL are resolved, and VendorSeverity is cleared
// so it is not emitted in results. Entries whose lookup fails are left
// unchanged and a warning is logged.
func (c Client) FillInfo(vulns []types.DetectedVulnerability, reportType string) {
	// The preferred data source depends only on reportType, so resolve it once
	// instead of on every iteration.
	var source string
	switch reportType {
	case vulnerability.Ubuntu, vulnerability.Alpine, vulnerability.RedHat,
		vulnerability.RedHatOVAL, vulnerability.Debian, vulnerability.DebianOVAL,
		vulnerability.Fedora, vulnerability.Amazon, vulnerability.OracleOVAL,
		vulnerability.SuseCVRF, vulnerability.OpenSuseCVRF, vulnerability.Photon:
		source = reportType
	case vulnerability.CentOS: // CentOS doesn't have its own so we use RedHat
		source = vulnerability.RedHat
	case "npm", "yarn":
		source = vulnerability.NodejsSecurityWg
	case "pipenv", "poetry":
		source = vulnerability.PythonSafetyDB
	case "bundler":
		source = vulnerability.RubySec
	case "cargo":
		source = vulnerability.RustSec
	case "composer":
		source = vulnerability.PhpSecurityAdvisories
	}

	for i := range vulns {
		vuln, err := c.dbc.GetVulnerability(vulns[i].VulnerabilityID)
		if err != nil {
			// The logger terminates lines itself; no trailing newline here.
			log.Logger.Warnf("Error while getting vulnerability details: %s", err)
			continue
		}
		// Assign only after a successful lookup so a failure does not wipe
		// whatever the entry already carried.
		vulns[i].Vulnerability = vuln

		vulns[i].Severity, vulns[i].SeveritySource = c.getVendorSeverity(&vulns[i], source)
		vulns[i].PrimaryURL = c.getPrimaryURL(vulns[i].VulnerabilityID, vulns[i].References, source)
		vulns[i].Vulnerability.VendorSeverity = nil // Remove VendorSeverity from Results
	}
}

// getVendorSeverity decides which severity to report for vuln and names the
// source it came from. The vendor severity recorded for source wins, then the
// one recorded for NVD. If neither exists, vuln.Severity is returned with an
// empty source, falling back to dbTypes.SeverityUnknown when it is empty.
func (c Client) getVendorSeverity(vuln *types.DetectedVulnerability, source string) (string, string) {
	for _, candidate := range []string{source, vulnerability.Nvd} {
		if severity, found := vuln.VendorSeverity[candidate]; found {
			return severity.String(), candidate
		}
	}

	if vuln.Severity != "" {
		return vuln.Severity, ""
	}
	return dbTypes.SeverityUnknown.String(), ""
}

// getPrimaryURL chooses the most relevant advisory URL for a vulnerability.
// IDs with a known scheme (CVE-, RUSTSEC-, GHSA-, TEMP-) map directly to a
// fixed page. Otherwise the first reference starting with one of source's
// prefixes in primaryURLPrefixes is returned, with prefixes tried in order.
// An empty string means nothing matched.
func (c Client) getPrimaryURL(vulnID string, refs []string, source string) string {
	if strings.HasPrefix(vulnID, "CVE-") {
		return "https://avd.aquasec.com/nvd/" + strings.ToLower(vulnID)
	}

	fixedPages := []struct{ idPrefix, baseURL string }{
		{"RUSTSEC-", "https://rustsec.org/advisories/"},
		{"GHSA-", "https://github.com/advisories/"},
		{"TEMP-", "https://security-tracker.debian.org/tracker/"},
	}
	for _, page := range fixedPages {
		if strings.HasPrefix(vulnID, page.idPrefix) {
			return page.baseURL + vulnID
		}
	}

	for _, want := range primaryURLPrefixes[source] {
		for _, candidate := range refs {
			if strings.HasPrefix(candidate, want) {
				return candidate
			}
		}
	}
	return ""
}

// Filter returns the vulnerabilities worth reporting, sorted by severity.
//
// A vulnerability survives only if:
//   - its severity is one of severities,
//   - it has a fixed version, when ignoreUnfixed is set, and
//   - its ID is not listed in ignoreFile.
//
// When policyFile is non-empty, the Rego policy it contains is then applied
// to the survivors, and any policy error is returned wrapped.
func (c Client) Filter(ctx context.Context, vulns []types.DetectedVulnerability, severities []dbTypes.Severity,
	ignoreUnfixed bool, ignoreFile string, policyFile string) ([]types.DetectedVulnerability, error) {
	ignoredIDs := getIgnoredIDs(ignoreFile)

	severityWanted := func(severity string) bool {
		for _, s := range severities {
			if s.String() == severity {
				return true
			}
		}
		return false
	}

	var vulnerabilities []types.DetectedVulnerability
	for _, vuln := range vulns {
		if !severityWanted(vuln.Severity) {
			continue
		}
		if ignoreUnfixed && vuln.FixedVersion == "" {
			continue
		}
		if utils.StringInSlice(vuln.VulnerabilityID, ignoredIDs) {
			continue
		}
		vulnerabilities = append(vulnerabilities, vuln)
	}

	if policyFile != "" {
		kept, err := applyPolicy(ctx, vulnerabilities, policyFile)
		if err != nil {
			return nil, xerrors.Errorf("failed to apply the policy: %w", err)
		}
		vulnerabilities = kept
	}

	sort.Sort(types.BySeverity(vulnerabilities))
	return vulnerabilities, nil
}

// applyPolicy drops the vulnerabilities that the Rego policy in policyFile
// marks as ignored. The policy is compiled together with `module` (assumed
// to be the shared Rego library source declared elsewhere in this package —
// confirm), and the query data.trivy.ignore is evaluated once per
// vulnerability, with the vulnerability as input:
//   - an undefined result keeps the vulnerability,
//   - true drops it and false keeps it,
//   - any non-boolean result is an error.
func applyPolicy(ctx context.Context, vulns []types.DetectedVulnerability, policyFile string) ([]types.DetectedVulnerability, error) {
	source, err := ioutil.ReadFile(policyFile)
	if err != nil {
		return nil, xerrors.Errorf("unable to read the policy file: %w", err)
	}

	prepared, err := rego.New(
		rego.Query("data.trivy.ignore"),
		rego.Module("lib.rego", module),
		rego.Module("trivy.rego", string(source)),
	).PrepareForEval(ctx)
	if err != nil {
		return nil, xerrors.Errorf("unable to prepare for eval: %w", err)
	}

	var kept []types.DetectedVulnerability
	for _, v := range vulns {
		rs, err := prepared.Eval(ctx, rego.EvalInput(v))
		if err != nil {
			return nil, xerrors.Errorf("unable to evaluate the policy: %w", err)
		}
		if len(rs) == 0 {
			// Undefined result: the policy says nothing, so keep it.
			kept = append(kept, v)
			continue
		}

		ignore, isBool := rs[0].Expressions[0].Value.(bool)
		switch {
		case !isBool:
			return nil, xerrors.New("the policy must return boolean")
		case !ignore:
			kept = append(kept, v)
		}
	}
	return kept, nil
}

// getIgnoredIDs reads ignoreFile and returns the vulnerability IDs it lists,
// one per line. Surrounding whitespace is trimmed, and blank lines and lines
// starting with "#" are skipped.
//
// The file is optional: if it cannot be opened, nil is returned. Reading is
// best-effort, so if a read error occurs mid-file, the IDs collected so far
// are returned.
func getIgnoredIDs(ignoreFile string) []string {
	f, err := os.Open(ignoreFile)
	if err != nil {
		// trivy must work even if no .trivyignore exist
		return nil
	}
	// Close the descriptor; previously it leaked on every call.
	defer f.Close()

	var ignoredIDs []string
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		ignoredIDs = append(ignoredIDs, line)
	}
	// scanner.Err() is deliberately not surfaced: the signature has no error
	// return, and a partially read ignore list is preferable to aborting.
	return ignoredIDs
}
